#!/usr/bin/env python3
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re
import sys
from functools import cache
from pathlib import Path
from subprocess import check_call

from taihe.utils.resources import ResourceLocator, ResourceType

# Dotted package name of the ANTLR-generated code; it is also interpolated
# into the import line written to TaiheVisitor.py.
ANTLR_PKG = "taihe.parse.antlr"
# Directory that contains this script.
CURRENT_DIR = Path(__file__).parent.resolve()
# Output directory: the package name mapped onto a path below CURRENT_DIR.
ANTLR_PATH = CURRENT_DIR / ANTLR_PKG.replace(".", "/")
# Grammar file handed to the ANTLR tool.
G4_FILE = CURRENT_DIR / "Taihe.g4"

# HACK: The parent package `taihe.parse` imports the very code that this script
# is about to generate. To avoid initializing that package, the output directory
# is put at the front of `sys.path` so the ANTLR-generated modules can be
# imported directly by their bare module name.
sys.path.insert(0, str(ANTLR_PATH))


@cache
def get_parser():
    """Return the ANTLR-generated ``TaiheParser`` class.

    The import is deferred to call time because the module only exists after
    the ANTLR tool has written it into the directory that was added to
    ``sys.path``. The result is memoized by ``functools.cache``, so the import
    statement runs at most once per process.
    """
    import TaiheParser as generated_module

    return generated_module.TaiheParser


def get_hint(attr_kind):
    """Return the type-hint source text for an attribute kind.

    A kind ending in ``Lst`` becomes ``List["TaiheAST.<base>"]``, a kind ending
    in ``Opt`` becomes ``Optional["TaiheAST.<base>"]`` (``<base>`` is the kind
    without its three-character suffix), and anything else becomes the plain
    forward reference ``"TaiheAST.<kind>"``.
    """
    # Suffix -> typing wrapper; checked in this order, first match wins.
    wrappers = (("Lst", "List"), ("Opt", "Optional"))
    for suffix, wrapper in wrappers:
        if attr_kind.endswith(suffix):
            base_kind = attr_kind[: -len(suffix)]
            return f'{wrapper}["TaiheAST.{base_kind}"]'
    return f'"TaiheAST.{attr_kind}"'


def get_attr_pairs(ctx):
    """Yield the split name of every public instance attribute of *ctx*.

    Attributes whose name starts with an underscore, and the attribute called
    ``parser``, are skipped. Each remaining name is split once at its first
    underscore, so ``"TypeLst_args"`` is yielded as ``["TypeLst", "args"]``.
    A name without any underscore is yielded as a single-element list.
    Attributes are visited in the insertion order of the instance ``__dict__``.
    """
    for full_name in vars(ctx):
        if full_name.startswith("_") or full_name == "parser":
            continue
        yield full_name.split("_", 1)


# Zero-width match in front of every ASCII uppercase letter except one at the
# very start of the string.
_UPPERCASE_BOUNDARY = re.compile(r"(?<!^)(?=[A-Z])")


def snake_case(name):
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII uppercase letter that is not
    the first character, then the whole string is lower-cased. Consecutive
    capitals are therefore separated individually (``"TOKEN"`` becomes
    ``"t_o_k_e_n"``).
    """
    with_separators = _UPPERCASE_BOUNDARY.sub("_", name)
    return with_separators.lower()


class Inspector:
    """Placeholder object whose context-like attributes are all ``None``.

    ``generate_ast`` passes an instance as the second constructor argument when
    it instantiates parser context classes purely to inspect their attributes.
    NOTE(review): the attribute names presumably mirror the fields the ANTLR
    runtime reads from a real rule context -- confirm against the runtime.
    """

    def __init__(self):
        # Assignment order is kept so the instance ``__dict__`` order is stable.
        self.parentCtx = self.invokingState = self.children = None
        self.start = self.stop = None


def generate_ast():
    """Generate ``TaiheAST.py`` inside the ANTLR output directory.

    The emitted module defines a ``TaiheAST`` namespace class that contains:

    * ``any``: the keyword-only dataclass base carrying ``loc``;
    * ``TOKEN``: a dataclass wrapping a token's ``text``;
    * for every parser context class that has subclasses, a ``Union[...]``
      alias listing those subclasses;
    * for every parser context class without subclasses, a dataclass with one
      field per public context attribute (see ``get_attr_pairs``) and an
      ``_accept`` method dispatching to ``visitor.visit_<snake_case_name>``.

    The complete text is assembled in memory and written in one go as UTF-8,
    so a failure while importing or inspecting the parser cannot leave a
    truncated file behind, and the output does not depend on the locale's
    default encoding.
    """
    # Import the parser first: if this fails, no output file is touched.
    parser = get_parser()

    chunks = [
        "from dataclasses import dataclass\n"
        "from typing import Any, Union, List, Optional\n"
        "\n"
        "from taihe.utils.sources import SourceLocation\n"
        "\n"
        "\n"
        "class TaiheAST:\n"
        "    @dataclass(kw_only=True)\n"
        "    class any:\n"
        "        loc: SourceLocation\n"
        "\n"
        "        def _accept(self, visitor) -> Any:\n"
        "            raise NotImplementedError()\n"
        "\n"
        "\n"
        "    @dataclass\n"
        "    class TOKEN(any):\n"
        "        text: str\n"
        "\n"
        "        def __str__(self):\n"
        "            return self.text\n"
        "\n"
        "        def _accept(self, visitor) -> Any:\n"
        "            return visitor.visit_token(self)\n"
        "\n"
    ]

    # Seed the worklist with one (node kind, context class) pair per grammar
    # rule; the context class of rule ``foo`` is looked up as ``FooContext``.
    worklist = []
    for rule_name in parser.ruleNames:
        node_kind = rule_name[0].upper() + rule_name[1:]
        worklist.append((node_kind, getattr(parser, node_kind + "Context")))

    # NOTE: the worklist intentionally grows while it is being iterated.
    # Subclasses discovered for a rule are appended and therefore emitted as
    # dataclasses after all entries that were queued before them.
    for node_kind, ctx_type in worklist:
        subclasses = ctx_type.__subclasses__()
        if subclasses:
            chunks.append(f"    {node_kind} = Union[\n")
            for sub_type in subclasses:
                # Context classes are named ``<Kind>Context``.
                sub_kind = sub_type.__name__.removesuffix("Context")
                worklist.append((sub_kind, sub_type))
                chunks.append(f"        {get_hint(sub_kind)},\n")
            chunks.append("    ]\n\n")
            continue

        # Instantiate the context with placeholder arguments only to discover
        # which attributes its constructor defines.
        ctx = ctx_type(None, Inspector())
        chunks.append(
            "    @dataclass\n"
            f"    class {node_kind}(any):\n"
        )
        for attr_kind, attr_name in get_attr_pairs(ctx):
            chunks.append(f"        {attr_name}: {get_hint(attr_kind)}\n")
        chunks.append(
            "\n"
            "        def _accept(self, visitor) -> Any:\n"
            f"            return visitor.visit_{snake_case(node_kind)}(self)\n"
            "\n"
        )

    with open(ANTLR_PATH / "TaiheAST.py", "w", encoding="utf-8") as file:
        file.write("".join(chunks))


def generate_visitor():
    """Generate ``TaiheVisitor.py`` inside the ANTLR output directory.

    The emitted ``TaiheVisitor`` class has a generic ``visit`` that calls
    ``node._accept(self)``, a ``visit_token`` stub, and for every grammar rule
    a ``visit_<rule>`` method raising ``NotImplementedError``. When a rule's
    context class has subclasses, a ``visit_<subclass>`` method is emitted for
    each of them (before the rule's own method) that delegates to
    ``visit_<rule>``.

    The complete text is assembled in memory and written in one go as UTF-8,
    so a failure while importing or inspecting the parser cannot leave a
    truncated file behind, and the output does not depend on the locale's
    default encoding.
    """
    # Import the parser first: if this fails, no output file is touched.
    parser = get_parser()

    chunks = [
        f"from {ANTLR_PKG}.TaiheAST import TaiheAST\n"
        "\n"
        "from typing import Any\n"
        "\n"
        "\n"
        "class TaiheVisitor:\n"
        "    def visit(self, node: TaiheAST.any) -> Any:\n"
        "        return node._accept(self)\n"
        "\n"
        "    def visit_token(self, node: TaiheAST.TOKEN) -> Any:\n"
        "        raise NotImplementedError()\n"
        "\n"
    ]

    for rule_name in parser.ruleNames:
        # The context class of rule ``foo`` is looked up as ``FooContext``.
        node_kind = rule_name[0].upper() + rule_name[1:]
        ctx_type = getattr(parser, node_kind + "Context")
        node_method = f"visit_{snake_case(node_kind)}"

        # Subclass visitors fall back to the visitor of their parent rule.
        for sub_type in ctx_type.__subclasses__():
            # Context classes are named ``<Kind>Context``.
            sub_kind = sub_type.__name__.removesuffix("Context")
            chunks.append(
                f"    def visit_{snake_case(sub_kind)}(self, node: TaiheAST.{sub_kind}) -> Any:\n"
                f"        return self.{node_method}(node)\n"
                "\n"
            )

        chunks.append(
            f"    def {node_method}(self, node: TaiheAST.{node_kind}) -> Any:\n"
            "        raise NotImplementedError()\n"
            "\n"
        )

    with open(ANTLR_PATH / "TaiheVisitor.py", "w", encoding="utf-8") as file:
        file.write("".join(chunks))


def run_antlr():
    """Run the ANTLR tool on the grammar file to emit the Python 3 parser.

    The ANTLR jar is obtained from the detected ``ResourceLocator``. The tool is
    invoked through ``java`` with listener generation disabled and
    ``ANTLR_PATH`` as the output directory. The child process is started with
    an empty environment (``env={}``), so no variables of the calling process
    are inherited.

    Raises:
        subprocess.CalledProcessError: If the tool exits with a non-zero code.
    """
    antlr_jar = ResourceLocator.detect().get(ResourceType.DEV_ANTLR)
    command = [
        "java",
        "-cp",
        str(antlr_jar),
        "org.antlr.v4.Tool",
        "-Dlanguage=Python3",
        "-no-listener",
        G4_FILE,
        "-o",
        ANTLR_PATH,
    ]
    check_call(command, env={})


def main():
    """Run the whole generation pipeline.

    The order matters: the ANTLR tool runs first, and the AST and visitor
    generators afterwards import the parser module from the ANTLR output
    directory.
    """
    for step in (run_antlr, generate_ast, generate_visitor):
        step()


# Run the generation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
